import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, Sequential
from correlation import correlation

class DispNet(tf.keras.Model):
    """Two-tower stereo network: encode, correlate, then refine coarse-to-fine.

    The left and right images are encoded by two structurally identical
    convolutional towers with separate weights (``conv*a`` vs ``conv*b``)
    down to 1/8 of the input resolution.  The two 1/8 feature maps are
    combined with the external ``correlation`` op.  The network then
    contracts further to 1/64 and expands back to 1/2 resolution with
    transposed convolutions.  Each expanding stage predicts a single-channel
    map, upsamples it, and concatenates it with the deconvolved features and
    an encoder skip connection.

    Input notes:
        * All concatenations are along ``axis=3``, so inputs are assumed to
          be rank-4 channels-last (NHWC).
        * The encoder max-pools six times (to 1/64).  Every expanding stage
          concatenates with an encoder feature map, so the input height and
          width should be divisible by 64 for those shapes to line up.
    """

    def __init__(self, **kwargs):
        """Create every layer of the network.

        Args:
            **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
        """
        super().__init__(**kwargs)
        # NOTE: attribute assignment order is preserved on purpose.  Keras
        # derives ``model.layers`` (and hence weight-file layer order) and
        # auto-generated layer names from it.

        # Left-image feature tower.
        self.conv1a = layers.Conv2D(64, (7, 7), padding='same', activation='relu')
        self.pool1a = layers.MaxPool2D((2, 2))
        self.conv3a = layers.Conv2D(128, (5, 5), padding='same', activation='relu')
        self.pool3a = layers.MaxPool2D((2, 2))
        self.conv17a = layers.Conv2D(256, (5, 5), padding='same', activation='relu')
        self.pool8a = layers.MaxPool2D((2, 2))

        # Right-image feature tower: same architecture, separate weights.
        self.conv1b = layers.Conv2D(64, (7, 7), padding='same', activation='relu')
        self.pool1b = layers.MaxPool2D((2, 2))
        self.conv3b = layers.Conv2D(128, (5, 5), padding='same', activation='relu')
        self.pool3b = layers.MaxPool2D((2, 2))
        self.conv17b = layers.Conv2D(256, (5, 5), padding='same', activation='relu')
        self.pool8b = layers.MaxPool2D((2, 2))

        # Not used by ``call`` (which uses the external ``correlation`` op
        # instead).  Kept so the model's attribute/layer layout is unchanged.
        self.corr = layers.DepthwiseConv2D(kernel_size=(1, 1), strides=(1, 1), padding='valid', depth_multiplier=1)
        # 1x1 projection of the left 1/8 features, concatenated with the correlation.
        self.conva = layers.Conv2D(32, (1, 1), padding='same', activation='relu')

        # Contracting path: 1/8 -> 1/64.
        self.conv4 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')
        self.conv9 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.pool5 = layers.MaxPool2D((2, 2))
        self.conv10 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.conv11 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.pool6 = layers.MaxPool2D((2, 2))
        self.conv12 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.conv13 = layers.Conv2D(1024, (3, 3), padding='same', activation='relu')
        self.pool7 = layers.MaxPool2D((2, 2))
        self.conv14 = layers.Conv2D(1024, (3, 3), padding='same', activation='relu')

        # Expanding stage 1/64 -> 1/32.
        self.conv18 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')
        self.up1 = layers.UpSampling2D((2, 2))
        self.deconv4 = layers.Conv2DTranspose(512, (4, 4), strides=(2, 2), padding='same', activation='relu')
        self.bn1 = layers.BatchNormalization()
        self.conv19 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')

        # Expanding stage 1/32 -> 1/16.
        self.conv20 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')
        self.up2 = layers.UpSampling2D((2, 2))
        self.deconv5 = layers.Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same', activation='relu')
        self.bn2 = layers.BatchNormalization()
        self.conv21 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')

        # Expanding stage 1/16 -> 1/8.
        self.deconv24 = layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', activation='relu')
        self.bn3 = layers.BatchNormalization()
        self.conv22 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')
        self.up3 = layers.UpSampling2D((2, 2))
        self.conv23 = layers.Conv2D(128, (3, 3), padding='same', activation='relu')

        # Expanding stage 1/8 -> 1/4.
        self.conv24 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')
        self.up4 = layers.UpSampling2D((2, 2))
        self.deconv7 = layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same', activation='relu')
        self.bn4 = layers.BatchNormalization()
        self.conv25 = layers.Conv2D(64, (3, 3), padding='same', activation='relu')

        # Expanding stage 1/4 -> 1/2, plus the final prediction head.
        self.conv26 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')
        self.up5 = layers.UpSampling2D((2, 2))
        self.deconv8 = layers.Conv2DTranspose(32, (4, 4), strides=(2, 2), padding='same', activation='relu')
        self.bn5 = layers.BatchNormalization()
        self.conv27 = layers.Conv2D(32, (3, 3), padding='same', activation='relu')
        self.conv28 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')

    def call(self, left, right, training=None):
        """Run the network on a stereo pair.

        Args:
            left: Left image batch.  Rank-4 channels-last tensor is assumed
                (batch, H, W, C) — TODO confirm against the input pipeline.
            right: Right image batch.  Expected to match ``left`` spatially
                so the two 1/8 feature maps can be correlated — confirm
                against ``correlation``.
            training: Passed explicitly to every BatchNormalization layer.
                ``None`` lets Keras infer the mode, the same as omitting it.

        Returns:
            Tuple ``(out, c26, c24, c22, c20)`` of single-channel ReLU maps at
            1/2, 1/4, 1/8, 1/16 and 1/32 of the input resolution
            respectively (presumably multi-scale disparity predictions,
            finest first).
        """
        # NOTE: Keras builds each layer on its first call, so the call order
        # below is kept identical to the original.  Reordering could change
        # the order of weight creation, and with it seeded initialisation.

        # Left tower.  Its pooled outputs (p1a, p3a) are reused as skips.
        c1a = self.conv1a(left)
        p1a = self.pool1a(c1a)        # 1/2
        c3a = self.conv3a(p1a)
        p3a = self.pool3a(c3a)        # 1/4
        c17a = self.conv17a(p3a)
        p8a = self.pool8a(c17a)       # 1/8

        # Right tower.  Only its 1/8 output feeds the correlation.
        c1b = self.conv1b(right)
        p1b = self.pool1b(c1b)
        c3b = self.conv3b(p1b)
        p3b = self.pool3b(c3b)
        c17b = self.conv17b(p3b)
        p8b = self.pool8b(c17b)

        # Correlate the two 1/8 feature maps and append a 1x1 projection of
        # the left features.  ``correlation`` must return a tensor with the
        # same batch and spatial dims as p8a, or this concat fails.
        cc = correlation(p8a, p8b)
        cc = tf.nn.leaky_relu(cc, alpha=0.1)
        ca = self.conva(p8a)
        net = tf.concat([ca, cc], axis=3)

        # Contracting path.
        c4 = self.conv4(net)          # 1/8
        c9 = self.conv9(c4)
        p5 = self.pool5(c9)
        c10 = self.conv10(p5)         # 1/16
        c11 = self.conv11(c10)
        p6 = self.pool6(c11)
        c12 = self.conv12(p6)         # 1/32
        c13 = self.conv13(c12)
        p7 = self.pool7(c13)
        c14 = self.conv14(p7)         # 1/64

        # Expanding path.  Each stage does three things:
        #   1. Predicts a 1-channel map from the current features and
        #      upsamples it x2.
        #   2. Upsamples the features x2 with a transposed conv and batch norm.
        #   3. Concatenates [encoder skip, deconv features, upsampled
        #      prediction] and fuses them with a conv.

        # 1/64 -> 1/32
        c18 = self.conv18(c14)
        u1 = self.up1(c18)
        d4 = self.deconv4(c14)
        b1 = self.bn1(d4, training=training)
        c19 = self.conv19(tf.concat([c12, b1, u1], axis=3))

        # 1/32 -> 1/16
        c20 = self.conv20(c19)
        u2 = self.up2(c20)
        d5 = self.deconv5(c19)
        b2 = self.bn2(d5, training=training)
        c21 = self.conv21(tf.concat([c10, b2, u2], axis=3))

        # 1/16 -> 1/8 (this stage runs the deconv before the prediction conv)
        d24 = self.deconv24(c21)
        b3 = self.bn3(d24, training=training)
        c22 = self.conv22(c21)
        u3 = self.up3(c22)
        c23 = self.conv23(tf.concat([c4, b3, u3], axis=3))

        # 1/8 -> 1/4.  The skip uses the left tower (p3a), not the right (p3b).
        c24 = self.conv24(c23)
        u4 = self.up4(c24)
        d7 = self.deconv7(c23)
        b4 = self.bn4(d7, training=training)
        c25 = self.conv25(tf.concat([p3a, b4, u4], axis=3))

        # 1/4 -> 1/2.  The skip uses the left tower (p1a), not the right (p1b).
        c26 = self.conv26(c25)
        u5 = self.up5(c26)
        d8 = self.deconv8(c25)
        b5 = self.bn5(d8, training=training)
        c27 = self.conv27(tf.concat([p1a, b5, u5], axis=3))

        out = self.conv28(c27)        # 1/2
        return out, c26, c24, c22, c20
